{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T08:53:37.735658Z",
     "start_time": "2025-09-03T08:53:37.530087Z"
    }
   },
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.8854055 , -0.46481943,  0.8836212 ], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T08:53:38.421838Z",
     "start_time": "2025-09-03T08:53:37.904745Z"
    }
   },
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjUsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvWftoOwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIk1JREFUeJzt3Q1wVOW9x/H/5v09IYEkIImkQgXKi/IipNarlUi0qCB4ax3GRsvVCyLDi8NUrOLoOBMG760vraJ3eivMHQWbXkFBoKYBQpXwFkQhQKQVIQJJILh5g2xe9tx5ns7uzUKAACfZZzffz8zxZM958uTsMdkfzzn/c47DsixLAAAwUIi/NwAAgIshpAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMbyW0i9+eabMmDAAImKipJx48bJzp07/bUpAABD+SWkPvjgA1mwYIG88MILsmfPHhk5cqTk5uZKdXW1PzYHAGAohz9uMKtGTmPHjpXf//73+rXb7ZaMjAyZM2eOPPPMM929OQAAQ4V19w9sbm6W0tJSWbRokXdZSEiI5OTkSElJSYff43K59OShQu3MmTOSkpIiDoejW7YbAGAfNT6qr6+Xfv366QwwJqROnz4tbW1tkpaW5rNcvT506FCH35Ofny8vvvhiN20hAKC7VFRUSP/+/c0JqauhRl3qHJZHbW2tZGZm6jeXkJDg120DAFy5uro6fZonPj7+ku26PaR69+4toaGhUlVV5bNcvU5PT+/weyIjI/V0PhVQhBQABK7LnbLp9uq+iIgIGT16tBQVFfmcY1Kvs7Ozu3tzAAAG88vhPnXoLi8vT8aMGSO33HKLvPbaa9LY2CiPPfaYPzYHAGAov4TUQw89JKdOnZLFixdLZWWl3HTTTbJx48YLiikAAD2bX66TsuOEW2Jioi6g4JwUAASezn6Oc+8+AICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQCAsQgpAICxCCkAgLEIKQBA8ITU1q1b5b777pN+/fqJw+GQNWvW+Ky3LEsWL14sffv2lejoaMnJyZHDhw/7tDlz5oxMnz5dEhISJCkpSWbMmCENDQ3X/m4AAD07pBobG2XkyJHy5ptvdrh+6dKl8sYbb8jbb78tO3bskNjYWMnNzZWmpiZvGxVQZWVlUlhYKOvWrdPB98QTT1zbOwEABB/rGqhvX716tfe12+220tPTrVdeecW7zOl0WpGRkdbKlSv16wMHDujv27Vrl7fNhg0bLIfDYR0/frxTP7e2tlb3oeYAgMDT2c9xW89JHTlyRCorK/UhPo/ExEQZN26clJSU6Ndqrg7xjRkzxttGtQ8JCdEjr464XC6pq6vzmQAAwc/WkFIBpaSlpfksV68969Q8NTXVZ31YWJgkJyd725wvPz9fh51nysjIsHOzAQCGCojqvkWLFkltba13qqio8PcmAQACLaTS09P1vKqqyme5eu1Zp+bV1dU+61tbW3XFn6fN+SIjI3UlYPsJABD8bA2prKwsHTRFRUXeZer8kTrXlJ2drV+rudPplNLSUm+bTZs2idvt1ueuAADwCJMrpK5n+vvf/+5TLLF37159TikzM1Pm
zZsnL7/8sgwaNEiH1vPPP6+vqZoyZYpuP2TIELn77rvl8ccf12XqLS0t8tRTT8kvfvEL3Q4AAK8rLRvcvHmzLhs8f8rLy/OWoT///PNWWlqaLj2fMGGCVV5e7tNHTU2N9fDDD1txcXFWQkKC9dhjj1n19fW2ly4CAMzU2c9xh/qPBBh1CFFV+akiCs5PAUDg6ezneEBU9wEAeiZCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKAGAsQgoAYCxCCgBgLEIKABAcIZWfny9jx46V+Ph4SU1NlSlTpkh5eblPm6amJpk9e7akpKRIXFycTJs2TaqqqnzaHDt2TCZNmiQxMTG6n4ULF0pra6s97wgA0DNDqri4WAfQ9u3bpbCwUFpaWmTixInS2NjobTN//nxZu3atFBQU6PYnTpyQqVOnete3tbXpgGpubpZt27bJihUrZPny5bJ48WJ73xkAIPBZ16C6utpSXRQXF+vXTqfTCg8PtwoKCrxtDh48qNuUlJTo1+vXr7dCQkKsyspKb5tly5ZZCQkJlsvl6tTPra2t1X2qOQAg8HT2c/yazknV1tbqeXJysp6Xlpbq0VVOTo63zeDBgyUzM1NKSkr0azUfPny4pKWledvk5uZKXV2dlJWVdfhzXC6XXt9+AgAEv6sOKbfbLfPmzZNbb71Vhg0bppdVVlZKRESEJCUl+bRVgaTWedq0DyjPes+6i50LS0xM9E4ZGRlXu9kAgJ4QUurc1P79+2XVqlXS1RYtWqRHbZ6poqKiy38mAMD/wq7mm5566ilZt26dbN26Vfr37+9dnp6ergsinE6nz2hKVfepdZ42O3fu9OnPU/3naXO+yMhIPQEAepYrGklZlqUDavXq1bJp0ybJysryWT969GgJDw+XoqIi7zJVoq5KzrOzs/VrNd+3b59UV1d726hKwYSEBBk6dOi1vyMAQM8cSalDfO+//7589NFH+lopzzkkdZ4oOjpaz2fMmCELFizQxRQqeObMmaODafz48bqtKllXYfTII4/I0qVLdR/PPfec7pvREgCgPYcq8ZNOcjgcHS5/99135dFHH/VezPv000/LypUrdVWeqtx76623fA7lHT16VGbNmiVbtmyR2NhYycvLkyVLlkhYWOcyU1X3qUBU56dUEAIAAktnP8evKKRMQUgBQGDr7Oc49+4DABiLkAIAGIuQAgAYi5ACABiLkAIAGIuQAgAYi5ACABiLkAIAGIuQAgAYi5ACABiLkAIAGIuQAgAYi5ACABiLkAIAGIuQAgAYi5ACABiLkAIAGIuQAgAYi5ACABgrzN8bAOD/WZZ10XUOh6NbtwUwASEFGMBqa5PW+nqp27NHnLt2SVNFhbSdOydhCQkSO3Cg9PrJTyTmhhskNDaWsEKPQkgBfuZ2ucS5fbtUrV0rZw8fVsMp77qWU6fk3D/+ITWbN0viqFGSOmWKxA0ZQlChxyCkAD9Sh/dOffqpVBYUSKvTefF2zc06yJpOnpTMJ56QuGHDCCr0CBROAH48xFfz17/Kiffeu2RAtdd09Kgc+6//koaDBy95/goIFoQU4CeNX3+tR1Dus2d9lh9vbJR1FRWy8ptv5K8nTkhjS8sFQXVy5Uppa2jo5i0Guh+H+wA/cLe0SO3u3eKq
rPQuUyOjIw0N8sIXX8i3DQ3S1NYmCeHhMqxXL/mPsWMlPOT//01Z/+WXcmzZMslauJDDfghqjKQAP2ipqZGqDz/0WfZNQ4M8/vnncrC2Vs61tYk6mFfb0iKfV1fL3B07pKapyad9/f794jp5spu3HOhehBTgB2rUpM5JtfdaWZkOpY7sPH1aCk+c8FmmzmNV/ulPYrndXbqtgD8RUkAAqy0t1Yf+KKJAsCKkgADWWlsrtXv2iHWRERgQ6AgpwBCTMjIk/CJFEAPi4mREcnKH66o/+kif4wKCESEF+EF4YqIk//SnPsty+/WTF26+WaJCQ71/mKEOh6RERsp/jh0rQ5OSLtqfunaKQ34IRpSgA/4QEiKhUVE+i1QpuQqq/jExsu6773Q1nxpBPZSVpYPqUs7+4x/SeOiQvmUSEEwIKcAP1AW86q4R51NBpa6LUtOVUJV+pwsL9U1oQyIibNxSwL843Af4QVhSkvS+6y49orLL99u2SX1ZGYf9EFQIKcAPHCEhkpSdLXGDB9s6OjuzZQuVfggqhBTgJxG9e0uv228XR3i4bX2qkGp/qyUg0BFSgB/1vvNOCU9Jsa9Dy5KKd94Rq7XVvj4BPyKkAD9yRERIxowZtvZ57tgxqdu719Y+AX8hpAA/UtV8sTfeKPE332zrXSi+//xzaTvvESBAICKkAD8LS0yUXtnZelRllzNbt0rDoUNU+iHgEVKAAaOp5Dvu0CMqu6gKv1OffHLBndaBQENIAQZQd59InzrV1kq/uj179J0ogEBGSAGGiB85UqL697etPzWKOr5ihbhdLtv6BLobIQUYwhEaKhn//u+2V/o5d++2tU+gOxFSgEHnpmKuv15SJkywrc+2ujpxbtsmrfX1tvUJdCdCCjBIaGyspN5/v0Red51tfapy9Mbycir9EJAIKcAwMVlZknTLLWpoZU+HbrdUfvihngOBhpACDJQ6ebKEXOYZUldCjaTUY+aBQENIAQYKT0qSftOn23rd1MlVq8RVVWVbn0B3IKQAQx/lkTh6tERnZdnW59nDh/Uzpzg3hUBCSAGGiuzbVxLHjLHv3JSIVK1ZI64TJwgqBAxCCjD4uilV6ReVkWFbn63ffy/OXbts6w/oaoQUYLDwxETpk5tr62iq8k9/Equ52bb+gK5ESAGGUzefDY2Ls62/tsZGqfjv/+aQHwICIQUYTgVU5syZ9nVoWdJ46JCcO3rUvj6BLkJIAYHwYMTBgyV+xAjb+jz37bf6Luk8ygOmI6SAABDRu7fE/ehHupjCLic/+EBfN8VhP5iMkAICZDTV7+GHJXHcONv6dJ87JzVFRbb1B3QFQgoIIOkPPmhrf6c//VRaa2tt7ROwEyEFBJDozEzpc889tvXXWlcnx//nfzg3BWMRUkAACYmIkF7/8i8S3ru3PR1aljQcOMBj5mEsQgoIMLGDBkn8sGG29ec6flzfId3d2mpbn4BfQmrZsmUyYsQISUhI0FN2drZs2LDBu76pqUlmz54tKSkpEhcXJ9OmTZOq8+66fOzYMZk0aZLExMRIamqqLFy4UFr54wCuaDTV96GHJCwpyda7UDRzh3QEekj1799flixZIqWlpbJ792658847ZfLkyVJWVqbXz58/X9auXSsFBQVSXFwsJ06ckKlTp3q/v62tTQdUc3OzbNu2TVasWCHLly+XxYsX2//OgCAWdd11kpSdbVt/VmurVK9da1t/gF0c1jVeJJGcnCyvvPKKPPjgg9KnTx95//339dfKoUOHZMiQIVJSUiLjx4/Xo657771Xh1daWppu8/bbb8uvf/1rOXXqlERERHTqZ9bV1UliYqLU1tbqER3QEzV9952UzZmj/vVnS3/hycnyw5dflqj+/W3pD7Djc/yqz0mpUdGqVauksbFRH/ZTo6uWlhbJycnxthk8eLBkZmbqkFLUfPjw4d6AUnJzc/XGekZjHXG5XLpN+wno6dSjPOx8MGLLmTNS+b//K25uPguDXHFI7du3
T59vioyMlJkzZ8rq1atl6NChUllZqUdCSecdJ1eBpNYpat4+oDzrPesuJj8/XyeuZ8qw8dEFQKBSd59IuOkmWx/lUb9vn5z95hvb+gO6PaRuvPFG2bt3r+zYsUNmzZoleXl5cuDAAelKixYt0kNCz1RRUdGlPw8IFDE/+IHE/vCHtvXXXF0tzm3bxN3SYlufQLeGlBotDRw4UEaPHq1HOCNHjpTXX39d0tPTdUGE0+n0aa+q+9Q6Rc3Pr/bzvPa06YgatXkqCj0TgH8+Zj7zyScl+oYbbOuz6uOPxXXypG39AX69TsrtdutzRiq0wsPDpajdvcDKy8t1ybk6Z6WouTpcWF1d7W1TWFioQ0cdMgRw5RxhYdL35z+3r0O3W5+bAkwQdqWH3e655x5dDFFfX68r+bZs2SJ/+ctf9LmiGTNmyIIFC3TFnwqeOXPm6GBSlX3KxIkTdRg98sgjsnTpUn0e6rnnntPXVqnREoCru/msukN6wpgxUrd7ty191n/5pTQcPChxQ4bY0h/QLSGlRkC//OUv5eTJkzqU1IW9KqDuuusuvf7VV1+VkJAQfRGvGl2pyr233nrL+/2hoaGybt06fS5LhVdsbKw+p/XSSy9d9RsAIBIWHy8pd94pjQcP6ifv2lHpp24+Gz1ggIRGR9uyjYBfrpPyB66TAi7U2tAg3yxZIvVffWVLf+qOFjc8+6zEDR5sS39At14nBcAsYXFxkv7zn4ujkxfFX06r0ymnN26k0g9+RUgBQUTdeDbOxiKkmi1bxLljh239AVeKkAKCicMh6dOm2def2y1Vq1fzvCn4DSEFBFmln7q4t3durm19qjtQVKqgCrzT1wgChBQQZFQ1XuItt0hYr172dNjWJvV790pLTY09/QFXgJACglDCzTdLtI13M1cVg2ePHGE0hW5HSAFBKCQsTAbMny+O8HDb+vz2tdfE7XLZ1h/QGYQUEKTCe/WSXj/5iW39tZ09KzWbN9vWH9AZhBQQrEJC9D39on/wA3v6a2uTM5s26YuGge5CSAFBXOkX2a+fpNxxh22H/Rq//lpOrV8vltttS3/A5RBSQJAHVa/bbtNP8bWFZemLe5tPn7anP+AyCCkgyEWkpEjqvffqC33tcPbwYX0jWyr90B0IKaAHSL7jDn3ozy7H3n5bF1IAXY2QAnqAkIgISX/wQdv6a2tqklMbNtjWH3AxhBTQQx4zn3DTTZIwapQ9Hba1Se327eI6dcqe/oCLIKSAHiI8OVmSxo2TkKgo2yr9VEk6N59FVyKkgB5U6Zf8059KRJ8+tvV5prhYP8UX6CqEFNCDhEZFScbMmbb11/Tdd+Lcvp1KP3QZQgroYWIHDpTYIUNs6+/EqlXSWltrW39Ae4QU0MOoc1L6uimbuM+e/eddKBhNoQsQUkAPPDelnjfV++67bbnAVxVOOHfulKZjx2zZPqA9QgrogUIjIyV10iR9p3Q7nPvmG/m+pETcra229Ad4EFJADxV9/fW2Psrj9KefSmtdnW39AQohBfRgaZMnS2hcnC19tZw+Lac3bODcFGxFSAE9WFivXpJ6//229Xdq40ZpqqiwrT+AkAJ6MPWYeXUXiuisLFv6a62vl5qiIkZTsA0hBfRw0QMG/POefqGh196Z2y0Nhw6Jq7LSjk0DCCmgp1Ml6ekPPCBh8fG29KeeNXX2669t6QsgpABIaHy89H3oIX9vBnABQgqAps5NRaSl+XszAB+EFAB9yE89ysPO2yUBdiCkAHgfjNjrttsk7kc/uraOVAGGHUUYACEFoD11myRV6ee4hpBJHDtWPwUYsAMhBcC30u/BByXlrruu6vtDoqMl+bbbJMymu1gAhBSAC6igilejoSu4S7ojLEz6TJokSdnZXbpt6FkIKQAXjKbUI+Yz/u3fJHHMmE4FVUhMjPS5917p+6//qu9iAdiF3yYAHQZVdGam9J8xQwfWmeJiaWts7LCtKltXFwOrO6qHRkd3+7YiuBFSAC4qsm9f
ue7RRyUlJ0ec27dL/f790nLmjD60F3XddbpAQo22IlJS9DLAbvxWAbjkiCo0KkpibrhBTxdrA3QVQgrAZRFE8BcKJwAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADBGVJLliwRh8Mh8+bN8y5ramqS2bNnS0pKisTFxcm0adOkqqrK5/uOHTsmkyZNkpiYGElNTZWFCxdKa2vrtWwKACAIXXVI7dq1S9555x0ZMWKEz/L58+fL2rVrpaCgQIqLi+XEiRMydepU7/q2tjYdUM3NzbJt2zZZsWKFLF++XBYvXnxt7wQAEHysq1BfX28NGjTIKiwstG6//XZr7ty5ernT6bTCw8OtgoICb9uDBw9a6seUlJTo1+vXr7dCQkKsyspKb5tly5ZZCQkJlsvl6tTPr62t1X2qOQAg8HT2c/yqRlLqcJ4aDeXk5PgsLy0tlZaWFp/lgwcPlszMTCkpKdGv1Xz48OGSlpbmbZObmyt1dXVSVlbW4c9zuVx6ffsJABD8wq70G1atWiV79uzRh/vOV1lZKREREZKUlOSzXAWSWudp0z6gPOs96zqSn58vL7744pVuKgAgwF3RSKqiokLmzp0r7733nkRFRUl3WbRokdTW1nontR0AgOB3RSGlDudVV1fLqFGjJCwsTE+qOOKNN97QX6sRkSqIcDqdPt+nqvvS09P112p+frWf57WnzfkiIyMlISHBZwIABL8rCqkJEybIvn37ZO/evd5pzJgxMn36dO/X4eHhUlRU5P2e8vJyXXKenZ2tX6u56kOFnUdhYaEOnqFDh9r53gAAPemcVHx8vAwbNsxnWWxsrL4myrN8xowZsmDBAklOTtbBM2fOHB1M48eP1+snTpyow+iRRx6RpUuX6vNQzz33nC7GUCMmAACuunDicl599VUJCQnRF/GqqjxVuffWW29514eGhsq6detk1qxZOrxUyOXl5clLL71k96YAAAKcQ9WhS4BRJeiJiYm6iILzUwAQeDr7Oc69+wAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGIqQAAMYipAAAxiKkAADGCpMAZFmWntfV1fl7UwAAV8Hz+e35PA+qkKqpqdHzjIwMf28KAOAa1NfXS2JiYnCFVHJysp4fO3bskm+up1P/UlFBXlFRIQkJCf7eHGOxnzqH/dQ57KfOUSMoFVD9+vW7ZLuADKmQkH+eSlMBxS/B5al9xH66PPZT57CfOof9dHmdGWRQOAEAMBYhBQAwVkCGVGRkpLzwwgt6jotjP3UO+6lz2E+dw36yl8O6XP0fAAB+EpAjKQBAz0BIAQCMRUgBAIxFSAEAjBWQIfXmm2/KgAEDJCoqSsaNGyc7d+6UnmTr1q1y33336Su1HQ6HrFmzxme9qoVZvHix9O3bV6KjoyUnJ0cOHz7s0+bMmTMyffp0fbFhUlKSzJgxQxoaGiRY5Ofny9ixYyU+Pl5SU1NlypQpUl5e7tOmqalJZs+eLSkpKRIXFyfTpk2Tqqoqnzbq
riaTJk2SmJgY3c/ChQultbVVgsWyZctkxIgR3gtPs7OzZcOGDd717KOOLVmyRP/tzZs3z7uMfdVFrACzatUqKyIiwvrjH/9olZWVWY8//riVlJRkVVVVWT3F+vXrrd/85jfWhx9+qCozrdWrV/usX7JkiZWYmGitWbPG+vLLL63777/fysrKss6dO+dtc/fdd1sjR460tm/fbv3tb3+zBg4caD388MNWsMjNzbXeffdda//+/dbevXutn/3sZ1ZmZqbV0NDgbTNz5kwrIyPDKioqsnbv3m2NHz/e+vGPf+xd39raag0bNszKycmxvvjiC73fe/fubS1atMgKFh9//LH1ySefWF9//bVVXl5uPfvss1Z4eLjebwr76EI7d+60BgwYYI0YMcKaO3eudzn7qmsEXEjdcsst1uzZs72v29rarH79+ln5+flWT3R+SLndbis9Pd165ZVXvMucTqcVGRlprVy5Ur8+cOCA/r5du3Z522zYsMFyOBzW8ePHrWBUXV2t33NxcbF3n6gP44KCAm+bgwcP6jYlJSX6tfoQCQkJsSorK71tli1bZiUkJFgul8sKVr169bL+8Ic/sI86UF9fbw0aNMgqLCy0br/9dm9Isa+6TkAd7mtubpbS0lJ9+Kr9ffzU65KSEr9umymOHDkilZWVPvtI3R9LHRb17CM1V4f4xowZ422j2qt9uWPHDglGtbW1PjcnVr9HLS0tPvtp8ODBkpmZ6bOfhg8fLmlpad42ubm5+gaiZWVlEmza2tpk1apV0tjYqA/7sY8upA7nqcN17feJwr7qOgF1g9nTp0/rP6T2/5MV9frQoUN+2y6TqIBSOtpHnnVqro6HtxcWFqY/wD1tgonb7dbnDm699VYZNmyYXqbeZ0REhA7rS+2njvajZ12w2Ldvnw4ldU5FnUtZvXq1DB06VPbu3cs+akcF+J49e2TXrl0XrOP3qesEVEgBV/uv3/3798tnn33m700x0o033qgDSY02//znP0teXp4UFxf7e7OMoh67MXfuXCksLNQFW+g+AXW4r3fv3hIaGnpBxYx6nZ6e7rftMolnP1xqH6l5dXW1z3pVYaQq/oJtPz711FOybt062bx5s/Tv39+7XL1PdfjY6XRecj91tB8964KFGgEMHDhQRo8erasiR44cKa+//jr76LzDeepvZtSoUfqog5pUkL/xxhv6azUiYl91jZBA+2NSf0hFRUU+h3LUa3W4AiJZWVn6F779PlLHvNW5Js8+UnP1x6T+8Dw2bdqk96U6dxUMVE2JCih16Eq9N7Vf2lO/R+Hh4T77SZWoqxLh9vtJHQprH+jqX9KqVFsdDgtW6vfA5XKxj9qZMGGCfp9qxOmZ1DlddRmH52v2VRexArAEXVWqLV++XFepPfHEE7oEvX3FTLBTFUaqhFVN6n/hb3/7W/310aNHvSXoap989NFH1ldffWVNnjy5wxL0m2++2dqxY4f12Wef6YqlYCpBnzVrli7D37Jli3Xy5EnvdPbsWZ+SYVWWvmnTJl0ynJ2drafzS4YnTpyoy9g3btxo9enTJ6hKhp955hld8XjkyBH9u6JeqyrPTz/9VK9nH11c++o+hX3VNQIupJTf/e53+pdBXS+lStLVtT49yebNm3U4nT/l5eV5y9Cff/55Ky0tTQf6hAkT9DUw7dXU1OhQiouL0yWwjz32mA6/YNHR/lGTunbKQ4X2k08+qUuuY2JirAceeEAHWXvffvutdc8991jR0dH6mpann37aamlpsYLFr371K+v666/Xf0vqA1P9rngCSmEfdT6k2Fddg0d1AACMFVDnpAAAPQshBQAwFiEFADAWIQUAMBYhBQAwFiEFADAWIQUAMBYhBQAwFiEFADAWIQUAMBYhBQAwFiEFABBT/R+ofEsWaWoXagAAAABJRU5ErkJggg=="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T08:53:39.068817Z",
     "start_time": "2025-09-03T08:53:38.431050Z"
    }
   },
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#定义模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc_statu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 128),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        self.fc_mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "\n",
    "        self.fc_std = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Softplus(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.fc_statu(state)\n",
    "\n",
    "        mu = self.fc_mu(state) * 2.0\n",
    "        std = self.fc_std(state)\n",
    "\n",
    "        return mu, std\n",
    "\n",
    "\n",
    "model = Model()\n",
    "\n",
    "model_td = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 1),\n",
    ")\n",
    "\n",
    "model(torch.randn(2, 3)), model_td(torch.randn(2, 3))"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([[-0.7190],\n",
       "          [-0.2408]], grad_fn=<MulBackward0>),\n",
       "  tensor([[0.9482],\n",
       "          [0.8217]], grad_fn=<SoftplusBackward0>)),\n",
       " tensor([[0.1461],\n",
       "         [0.1278]], grad_fn=<AddmmBackward0>))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T08:53:39.079740Z",
     "start_time": "2025-09-03T08:53:39.075065Z"
    }
   },
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state):\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    mu, std = model(state)\n",
    "\n",
    "    #根据概率选择一个动作\n",
    "    #action = random.normalvariate(mu=mu.item(), sigma=std.item())\n",
    "    action = torch.distributions.Normal(mu, std).sample().item()\n",
    "\n",
    "    return action\n",
    "\n",
    "\n",
    "get_action([1, 2, 3])"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1.7710306644439697"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2025-09-03T08:53:39.124311Z",
     "start_time": "2025-09-03T08:53:39.100773Z"
    }
   },
   "source": [
    "def get_data():\n",
    "    states = []\n",
    "    rewards = []\n",
    "    actions = []\n",
    "    next_states = []\n",
    "    overs = []\n",
    "\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        action = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        next_state, reward, over, _ = env.step([action])\n",
    "\n",
    "        #记录数据样本\n",
    "        states.append(state)\n",
    "        rewards.append(reward)\n",
    "        actions.append(action)\n",
    "        next_states.append(next_state)\n",
    "        overs.append(over)\n",
    "\n",
    "        #更新游戏状态,开始下一个动作\n",
    "        state = next_state\n",
    "\n",
    "    #[b, 3]\n",
    "    states = torch.FloatTensor(states).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    rewards = torch.FloatTensor(rewards).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    actions = torch.FloatTensor(actions).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_states = torch.FloatTensor(next_states).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    overs = torch.LongTensor(overs).reshape(-1, 1)\n",
    "\n",
    "    return states, rewards, actions, next_states, overs\n",
    "\n",
    "\n",
    "get_data()"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/s9/yht5_svd6mxft5fpm7ht48f80000gn/T/ipykernel_70786/277701111.py:31: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:256.)\n",
      "  states = torch.FloatTensor(states).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-5.8219e-02, -9.9830e-01, -5.4053e-01],\n",
       "         [-1.2222e-01, -9.9250e-01, -1.2855e+00],\n",
       "         [-2.2270e-01, -9.7489e-01, -2.0412e+00],\n",
       "         [-3.5320e-01, -9.3555e-01, -2.7281e+00],\n",
       "         [-5.1124e-01, -8.5944e-01, -3.5129e+00],\n",
       "         [-6.7750e-01, -7.3552e-01, -4.1546e+00],\n",
       "         [-8.2667e-01, -5.6269e-01, -4.5759e+00],\n",
       "         [-9.4257e-01, -3.3401e-01, -5.1417e+00],\n",
       "         [-9.9763e-01, -6.8863e-02, -5.4327e+00],\n",
       "         [-9.7932e-01,  2.0234e-01, -5.4533e+00],\n",
       "         [-8.9561e-01,  4.4483e-01, -5.1449e+00],\n",
       "         [-7.6215e-01,  6.4740e-01, -4.8634e+00],\n",
       "         [-5.9843e-01,  8.0117e-01, -4.5018e+00],\n",
       "         [-4.4529e-01,  8.9539e-01, -3.6010e+00],\n",
       "         [-3.1174e-01,  9.5017e-01, -2.8895e+00],\n",
       "         [-2.0532e-01,  9.7870e-01, -2.2047e+00],\n",
       "         [-1.2941e-01,  9.9159e-01, -1.5404e+00],\n",
       "         [-9.0604e-02,  9.9589e-01, -7.8085e-01],\n",
       "         [-9.8803e-02,  9.9511e-01,  1.6471e-01],\n",
       "         [-1.3320e-01,  9.9109e-01,  6.9272e-01],\n",
       "         [-2.0682e-01,  9.7838e-01,  1.4944e+00],\n",
       "         [-3.0596e-01,  9.5204e-01,  2.0526e+00],\n",
       "         [-4.3130e-01,  9.0221e-01,  2.6995e+00],\n",
       "         [-5.7842e-01,  8.1574e-01,  3.4171e+00],\n",
       "         [-7.2121e-01,  6.9271e-01,  3.7754e+00],\n",
       "         [-8.5762e-01,  5.1428e-01,  4.5016e+00],\n",
       "         [-9.5434e-01,  2.9872e-01,  4.7362e+00],\n",
       "         [-9.9810e-01,  6.1606e-02,  4.8342e+00],\n",
       "         [-9.8435e-01, -1.7620e-01,  4.7754e+00],\n",
       "         [-9.1427e-01, -4.0510e-01,  4.7993e+00],\n",
       "         [-8.0742e-01, -5.8998e-01,  4.2788e+00],\n",
       "         [-6.8368e-01, -7.2978e-01,  3.7394e+00],\n",
       "         [-5.6346e-01, -8.2614e-01,  3.0845e+00],\n",
       "         [-4.7091e-01, -8.8218e-01,  2.1649e+00],\n",
       "         [-3.9809e-01, -9.1735e-01,  1.6179e+00],\n",
       "         [-3.6210e-01, -9.3214e-01,  7.7814e-01],\n",
       "         [-3.5559e-01, -9.3464e-01,  1.3956e-01],\n",
       "         [-3.8235e-01, -9.2402e-01, -5.7592e-01],\n",
       "         [-4.3891e-01, -8.9853e-01, -1.2408e+00],\n",
       "         [-5.2294e-01, -8.5237e-01, -1.9182e+00],\n",
       "         [-6.2591e-01, -7.7990e-01, -2.5200e+00],\n",
       "         [-7.4493e-01, -6.6714e-01, -3.2827e+00],\n",
       "         [-8.5556e-01, -5.1771e-01, -3.7239e+00],\n",
       "         [-9.4139e-01, -3.3732e-01, -4.0021e+00],\n",
       "         [-9.9224e-01, -1.2434e-01, -4.3880e+00],\n",
       "         [-9.9523e-01,  9.7534e-02, -4.4471e+00],\n",
       "         [-9.5371e-01,  3.0072e-01, -4.1552e+00],\n",
       "         [-8.7398e-01,  4.8597e-01, -4.0404e+00],\n",
       "         [-7.6490e-01,  6.4415e-01, -3.8489e+00],\n",
       "         [-6.4563e-01,  7.6365e-01, -3.3807e+00],\n",
       "         [-5.2790e-01,  8.4931e-01, -2.9145e+00],\n",
       "         [-4.2550e-01,  9.0496e-01, -2.3321e+00],\n",
       "         [-3.3918e-01,  9.4072e-01, -1.8695e+00],\n",
       "         [-2.7908e-01,  9.6027e-01, -1.2640e+00],\n",
       "         [-2.3981e-01,  9.7082e-01, -8.1344e-01],\n",
       "         [-2.3409e-01,  9.7221e-01, -1.1773e-01],\n",
       "         [-2.5320e-01,  9.6741e-01,  3.9413e-01],\n",
       "         [-2.9879e-01,  9.5432e-01,  9.4881e-01],\n",
       "         [-3.7081e-01,  9.2871e-01,  1.5290e+00],\n",
       "         [-4.5996e-01,  8.8794e-01,  1.9614e+00],\n",
       "         [-5.7458e-01,  8.1845e-01,  2.6829e+00],\n",
       "         [-7.0397e-01,  7.1023e-01,  3.3777e+00],\n",
       "         [-8.2005e-01,  5.7230e-01,  3.6104e+00],\n",
       "         [-9.1379e-01,  4.0620e-01,  3.8203e+00],\n",
       "         [-9.7494e-01,  2.2247e-01,  3.8787e+00],\n",
       "         [-9.9964e-01,  2.6768e-02,  3.9516e+00],\n",
       "         [-9.8564e-01, -1.6888e-01,  3.9292e+00],\n",
       "         [-9.3589e-01, -3.5228e-01,  3.8064e+00],\n",
       "         [-8.5995e-01, -5.1038e-01,  3.5124e+00],\n",
       "         [-7.7939e-01, -6.2654e-01,  2.8296e+00],\n",
       "         [-7.0392e-01, -7.1028e-01,  2.2558e+00],\n",
       "         [-6.4231e-01, -7.6645e-01,  1.6679e+00],\n",
       "         [-5.9569e-01, -8.0321e-01,  1.1875e+00],\n",
       "         [-5.6950e-01, -8.2199e-01,  6.4456e-01],\n",
       "         [-5.6684e-01, -8.2383e-01,  6.4653e-02],\n",
       "         [-5.9181e-01, -8.0608e-01, -6.1256e-01],\n",
       "         [-6.4028e-01, -7.6814e-01, -1.2314e+00],\n",
       "         [-7.0666e-01, -7.0755e-01, -1.7981e+00],\n",
       "         [-7.8145e-01, -6.2397e-01, -2.2442e+00],\n",
       "         [-8.5537e-01, -5.1802e-01, -2.5856e+00],\n",
       "         [-9.2347e-01, -3.8367e-01, -3.0154e+00],\n",
       "         [-9.7588e-01, -2.1829e-01, -3.4741e+00],\n",
       "         [-9.9918e-01, -4.0569e-02, -3.5896e+00],\n",
       "         [-9.8962e-01,  1.4374e-01, -3.6964e+00],\n",
       "         [-9.4501e-01,  3.2704e-01, -3.7786e+00],\n",
       "         [-8.6994e-01,  4.9316e-01, -3.6510e+00],\n",
       "         [-7.7567e-01,  6.3114e-01, -3.3461e+00],\n",
       "         [-6.7315e-01,  7.3950e-01, -2.9861e+00],\n",
       "         [-5.7409e-01,  8.1879e-01, -2.5394e+00],\n",
       "         [-4.8772e-01,  8.7300e-01, -2.0403e+00],\n",
       "         [-4.3151e-01,  9.0211e-01, -1.2663e+00],\n",
       "         [-4.0091e-01,  9.1612e-01, -6.7316e-01],\n",
       "         [-4.0350e-01,  9.1498e-01,  5.6671e-02],\n",
       "         [-4.3591e-01,  8.9999e-01,  7.1414e-01],\n",
       "         [-4.9868e-01,  8.6679e-01,  1.4205e+00],\n",
       "         [-5.8125e-01,  8.1372e-01,  1.9639e+00],\n",
       "         [-6.7827e-01,  7.3481e-01,  2.5028e+00],\n",
       "         [-7.7271e-01,  6.3475e-01,  2.7539e+00],\n",
       "         [-8.5710e-01,  5.1516e-01,  2.9300e+00],\n",
       "         [-9.3419e-01,  3.5677e-01,  3.5276e+00],\n",
       "         [-9.8199e-01,  1.8891e-01,  3.4951e+00],\n",
       "         [-1.0000e+00,  2.0086e-03,  3.7609e+00],\n",
       "         [-9.8444e-01, -1.7573e-01,  3.5731e+00],\n",
       "         [-9.3714e-01, -3.4894e-01,  3.5959e+00],\n",
       "         [-8.7043e-01, -4.9229e-01,  3.1655e+00],\n",
       "         [-7.9780e-01, -6.0292e-01,  2.6489e+00],\n",
       "         [-7.2755e-01, -6.8606e-01,  2.1780e+00],\n",
       "         [-6.5974e-01, -7.5150e-01,  1.8853e+00],\n",
       "         [-6.0516e-01, -7.9611e-01,  1.4102e+00],\n",
       "         [-5.7328e-01, -8.1936e-01,  7.8923e-01],\n",
       "         [-5.6763e-01, -8.2329e-01,  1.3758e-01],\n",
       "         [-5.8269e-01, -8.1270e-01, -3.6824e-01],\n",
       "         [-6.1757e-01, -7.8652e-01, -8.7222e-01],\n",
       "         [-6.7264e-01, -7.3997e-01, -1.4425e+00],\n",
       "         [-7.4731e-01, -6.6448e-01, -2.1246e+00],\n",
       "         [-8.2555e-01, -5.6433e-01, -2.5434e+00],\n",
       "         [-9.0264e-01, -4.3040e-01, -3.0939e+00],\n",
       "         [-9.6315e-01, -2.6897e-01, -3.4521e+00],\n",
       "         [-9.9593e-01, -9.0100e-02, -3.6420e+00],\n",
       "         [-9.9491e-01,  1.0081e-01, -3.8241e+00],\n",
       "         [-9.6049e-01,  2.7833e-01, -3.6214e+00],\n",
       "         [-8.9845e-01,  4.3908e-01, -3.4504e+00],\n",
       "         [-8.1471e-01,  5.7987e-01, -3.2800e+00],\n",
       "         [-7.2693e-01,  6.8671e-01, -2.7677e+00],\n",
       "         [-6.4248e-01,  7.6630e-01, -2.3222e+00],\n",
       "         [-5.6516e-01,  8.2498e-01, -1.9420e+00],\n",
       "         [-5.1429e-01,  8.5761e-01, -1.2089e+00],\n",
       "         [-4.9186e-01,  8.7068e-01, -5.1930e-01],\n",
       "         [-4.9597e-01,  8.6834e-01,  9.4548e-02],\n",
       "         [-5.2453e-01,  8.5139e-01,  6.6428e-01],\n",
       "         [-5.7411e-01,  8.1878e-01,  1.1870e+00],\n",
       "         [-6.3547e-01,  7.7213e-01,  1.5420e+00],\n",
       "         [-7.1720e-01,  6.9686e-01,  2.2234e+00],\n",
       "         [-8.0272e-01,  5.9636e-01,  2.6411e+00],\n",
       "         [-8.8717e-01,  4.6144e-01,  3.1867e+00],\n",
       "         [-9.5257e-01,  3.0433e-01,  3.4077e+00],\n",
       "         [-9.9017e-01,  1.3985e-01,  3.3786e+00],\n",
       "         [-9.9930e-01, -3.7501e-02,  3.5564e+00],\n",
       "         [-9.7767e-01, -2.1014e-01,  3.4842e+00],\n",
       "         [-9.3286e-01, -3.6025e-01,  3.1362e+00],\n",
       "         [-8.7237e-01, -4.8884e-01,  2.8445e+00],\n",
       "         [-8.0098e-01, -5.9869e-01,  2.6220e+00],\n",
       "         [-7.3233e-01, -6.8095e-01,  2.1438e+00],\n",
       "         [-6.8097e-01, -7.3231e-01,  1.4530e+00],\n",
       "         [-6.3773e-01, -7.7026e-01,  1.1510e+00],\n",
       "         [-6.1341e-01, -7.8977e-01,  6.2345e-01],\n",
       "         [-6.1802e-01, -7.8616e-01, -1.1705e-01],\n",
       "         [-6.4217e-01, -7.6656e-01, -6.2219e-01],\n",
       "         [-6.8289e-01, -7.3052e-01, -1.0876e+00],\n",
       "         [-7.4342e-01, -6.6883e-01, -1.7291e+00],\n",
       "         [-8.1567e-01, -5.7852e-01, -2.3143e+00],\n",
       "         [-8.8508e-01, -4.6543e-01, -2.6558e+00],\n",
       "         [-9.4335e-01, -3.3179e-01, -2.9184e+00],\n",
       "         [-9.8300e-01, -1.8360e-01, -3.0710e+00],\n",
       "         [-9.9974e-01, -2.2914e-02, -3.2347e+00],\n",
       "         [-9.9080e-01,  1.3530e-01, -3.1726e+00],\n",
       "         [-9.5530e-01,  2.9565e-01, -3.2885e+00],\n",
       "         [-9.0092e-01,  4.3399e-01, -2.9756e+00],\n",
       "         [-8.3798e-01,  5.4570e-01, -2.5661e+00],\n",
       "         [-7.7672e-01,  6.2985e-01, -2.0828e+00],\n",
       "         [-7.2161e-01,  6.9230e-01, -1.6662e+00],\n",
       "         [-6.8019e-01,  7.3303e-01, -1.1619e+00],\n",
       "         [-6.5223e-01,  7.5802e-01, -7.5011e-01],\n",
       "         [-6.4871e-01,  7.6103e-01, -9.2548e-02],\n",
       "         [-6.6711e-01,  7.4496e-01,  4.8867e-01],\n",
       "         [-6.9772e-01,  7.1637e-01,  8.3765e-01],\n",
       "         [-7.5088e-01,  6.6044e-01,  1.5437e+00],\n",
       "         [-8.1282e-01,  5.8252e-01,  1.9915e+00],\n",
       "         [-8.7509e-01,  4.8395e-01,  2.3332e+00],\n",
       "         [-9.3019e-01,  3.6708e-01,  2.5859e+00],\n",
       "         [-9.7376e-01,  2.2756e-01,  2.9260e+00],\n",
       "         [-9.9702e-01,  7.7184e-02,  3.0462e+00],\n",
       "         [-9.9648e-01, -8.3846e-02,  3.2241e+00],\n",
       "         [-9.7434e-01, -2.2506e-01,  2.8612e+00],\n",
       "         [-9.3212e-01, -3.6214e-01,  2.8711e+00],\n",
       "         [-8.7756e-01, -4.7947e-01,  2.5898e+00],\n",
       "         [-8.1629e-01, -5.7764e-01,  2.3157e+00],\n",
       "         [-7.5861e-01, -6.5154e-01,  1.8756e+00],\n",
       "         [-7.1276e-01, -7.0140e-01,  1.3550e+00],\n",
       "         [-6.8027e-01, -7.3296e-01,  9.0604e-01],\n",
       "         [-6.6070e-01, -7.5065e-01,  5.2761e-01],\n",
       "         [-6.6140e-01, -7.5004e-01, -1.8606e-02],\n",
       "         [-6.8172e-01, -7.3162e-01, -5.4852e-01],\n",
       "         [-7.1458e-01, -6.9956e-01, -9.1830e-01],\n",
       "         [-7.6490e-01, -6.4415e-01, -1.4973e+00],\n",
       "         [-8.1892e-01, -5.7391e-01, -1.7727e+00],\n",
       "         [-8.7718e-01, -4.8016e-01, -2.2086e+00],\n",
       "         [-9.3102e-01, -3.6496e-01, -2.5450e+00],\n",
       "         [-9.7202e-01, -2.3488e-01, -2.7299e+00],\n",
       "         [-9.9538e-01, -9.6020e-02, -2.8185e+00],\n",
       "         [-9.9852e-01,  5.4412e-02, -3.0121e+00],\n",
       "         [-9.8143e-01,  1.9185e-01, -2.7721e+00],\n",
       "         [-9.5070e-01,  3.1012e-01, -2.4456e+00],\n",
       "         [-9.0940e-01,  4.1593e-01, -2.2729e+00],\n",
       "         [-8.6287e-01,  5.0543e-01, -2.0182e+00],\n",
       "         [-8.1206e-01,  5.8358e-01, -1.8651e+00],\n",
       "         [-7.7236e-01,  6.3518e-01, -1.3022e+00],\n",
       "         [-7.4084e-01,  6.7168e-01, -9.6477e-01],\n",
       "         [-7.2681e-01,  6.8684e-01, -4.1304e-01],\n",
       "         [-7.2547e-01,  6.8825e-01, -3.8912e-02]]),\n",
       " tensor([[ -2.6830],\n",
       "         [ -3.0326],\n",
       "         [ -3.6401],\n",
       "         [ -4.4763],\n",
       "         [ -5.6753],\n",
       "         [ -7.0868],\n",
       "         [ -8.5665],\n",
       "         [-10.4896],\n",
       "         [-12.3928],\n",
       "         [-11.6058],\n",
       "         [ -9.8327],\n",
       "         [ -8.3071],\n",
       "         [ -6.9251],\n",
       "         [ -5.4270],\n",
       "         [ -4.3988],\n",
       "         [ -3.6461],\n",
       "         [ -3.1292],\n",
       "         [ -2.8234],\n",
       "         [ -2.7929],\n",
       "         [ -2.9531],\n",
       "         [ -3.3899],\n",
       "         [ -3.9625],\n",
       "         [ -4.7960],\n",
       "         [ -5.9560],\n",
       "         [ -7.0743],\n",
       "         [ -8.7948],\n",
       "         [-10.2995],\n",
       "         [-11.8235],\n",
       "         [-11.0696],\n",
       "         [ -9.7284],\n",
       "         [ -8.1342],\n",
       "         [ -6.7979],\n",
       "         [ -5.6616],\n",
       "         [ -4.7175],\n",
       "         [ -4.1841],\n",
       "         [ -3.8294],\n",
       "         [ -3.7436],\n",
       "         [ -3.8871],\n",
       "         [ -4.2553],\n",
       "         [ -4.8670],\n",
       "         [ -5.6859],\n",
       "         [ -6.8918],\n",
       "         [ -8.1339],\n",
       "         [ -9.4287],\n",
       "         [-11.0273],\n",
       "         [-11.2451],\n",
       "         [ -9.7708],\n",
       "         [ -8.5724],\n",
       "         [ -7.4432],\n",
       "         [ -6.3083],\n",
       "         [ -5.3733],\n",
       "         [ -4.5873],\n",
       "         [ -4.0242],\n",
       "         [ -3.5990],\n",
       "         [ -3.3531],\n",
       "         [ -3.2690],\n",
       "         [ -3.3540],\n",
       "         [ -3.6036],\n",
       "         [ -4.0420],\n",
       "         [ -4.5822],\n",
       "         [ -5.4851],\n",
       "         [ -6.6757],\n",
       "         [ -7.7181],\n",
       "         [ -8.8786],\n",
       "         [-10.0152],\n",
       "         [-11.2637],\n",
       "         [-10.3761],\n",
       "         [ -9.1861],\n",
       "         [ -8.0287],\n",
       "         [ -6.8748],\n",
       "         [ -6.0395],\n",
       "         [ -5.4238],\n",
       "         [ -5.0205],\n",
       "         [ -4.7796],\n",
       "         [ -4.7245],\n",
       "         [ -4.8956],\n",
       "         [ -5.2848],\n",
       "         [ -5.8723],\n",
       "         [ -6.5943],\n",
       "         [ -7.4133],\n",
       "         [ -8.4611],\n",
       "         [ -9.7424],\n",
       "         [-10.9050],\n",
       "         [-10.3521],\n",
       "         [ -9.3156],\n",
       "         [ -8.2284],\n",
       "         [ -7.1648],\n",
       "         [ -6.2249],\n",
       "         [ -5.4078],\n",
       "         [ -4.7444],\n",
       "         [ -4.2288],\n",
       "         [ -3.9789],\n",
       "         [ -3.9451],\n",
       "         [ -4.1389],\n",
       "         [ -4.5824],\n",
       "         [ -5.1867],\n",
       "         [ -5.9952],\n",
       "         [ -6.7840],\n",
       "         [ -7.6226],\n",
       "         [ -8.9589],\n",
       "         [ -9.9339],\n",
       "         [-11.2730],\n",
       "         [-10.0687],\n",
       "         [ -9.0514],\n",
       "         [ -7.9035],\n",
       "         [ -6.9239],\n",
       "         [ -6.1673],\n",
       "         [ -5.6057],\n",
       "         [ -5.1307],\n",
       "         [ -4.8204],\n",
       "         [ -4.7305],\n",
       "         [ -4.8226],\n",
       "         [ -5.0778],\n",
       "         [ -5.5383],\n",
       "         [ -6.2829],\n",
       "         [ -7.1092],\n",
       "         [ -8.2293],\n",
       "         [ -9.4244],\n",
       "         [-10.6379],\n",
       "         [-10.7084],\n",
       "         [ -9.4885],\n",
       "         [ -8.4117],\n",
       "         [ -7.4417],\n",
       "         [ -6.4527],\n",
       "         [ -5.6871],\n",
       "         [ -5.0928],\n",
       "         [ -4.6025],\n",
       "         [ -4.3743],\n",
       "         [ -4.3682],\n",
       "         [ -4.5517],\n",
       "         [ -4.9064],\n",
       "         [ -5.3432],\n",
       "         [ -6.1145],\n",
       "         [ -6.9612],\n",
       "         [ -8.1023],\n",
       "         [ -9.1864],\n",
       "         [-10.1494],\n",
       "         [-10.9002],\n",
       "         [ -9.7997],\n",
       "         [ -8.6735],\n",
       "         [ -7.7313],\n",
       "         [ -6.9362],\n",
       "         [ -6.1853],\n",
       "         [ -5.5957],\n",
       "         [ -5.2508],\n",
       "         [ -5.0179],\n",
       "         [ -5.0059],\n",
       "         [ -5.1837],\n",
       "         [ -5.5127],\n",
       "         [ -6.1024],\n",
       "         [ -6.9100],\n",
       "         [ -7.7678],\n",
       "         [ -8.7111],\n",
       "         [ -9.6866],\n",
       "         [-10.7728],\n",
       "         [-10.0440],\n",
       "         [ -9.1556],\n",
       "         [ -8.1362],\n",
       "         [ -7.2347],\n",
       "         [ -6.4867],\n",
       "         [ -5.9274],\n",
       "         [ -5.5128],\n",
       "         [ -5.2610],\n",
       "         [ -5.1842],\n",
       "         [ -5.3210],\n",
       "         [ -5.5611],\n",
       "         [ -6.0957],\n",
       "         [ -6.7462],\n",
       "         [ -7.4957],\n",
       "         [ -8.3181],\n",
       "         [ -9.3362],\n",
       "         [-10.3187],\n",
       "         [-10.3927],\n",
       "         [ -9.3149],\n",
       "         [ -8.5029],\n",
       "         [ -7.6488],\n",
       "         [ -6.9157],\n",
       "         [ -6.2664],\n",
       "         [ -5.7734],\n",
       "         [ -5.4608],\n",
       "         [ -5.2836],\n",
       "         [ -5.2601],\n",
       "         [ -5.4181],\n",
       "         [ -5.6863],\n",
       "         [ -6.1879],\n",
       "         [ -6.7168],\n",
       "         [ -7.4614],\n",
       "         [ -8.3099],\n",
       "         [ -9.1817],\n",
       "         [-10.0697],\n",
       "         [-10.4396],\n",
       "         [ -9.4639],\n",
       "         [ -8.5860],\n",
       "         [ -7.8751],\n",
       "         [ -7.2307],\n",
       "         [ -6.6912],\n",
       "         [ -6.1894],\n",
       "         [ -5.8777],\n",
       "         [ -5.7036],\n",
       "         [ -5.6769]]),\n",
       " tensor([[ 0.0252],\n",
       "         [-0.0754],\n",
       "         [ 0.2951],\n",
       "         [-0.5541],\n",
       "         [ 0.0188],\n",
       "         [ 0.8692],\n",
       "         [-0.9590],\n",
       "         [-0.2699],\n",
       "         [ 0.2074],\n",
       "         [ 1.0443],\n",
       "         [-0.3479],\n",
       "         [-0.8263],\n",
       "         [ 2.2734],\n",
       "         [ 0.2664],\n",
       "         [-0.1856],\n",
       "         [-0.4649],\n",
       "         [ 0.1056],\n",
       "         [ 1.3243],\n",
       "         [-1.4554],\n",
       "         [ 0.3891],\n",
       "         [-1.1704],\n",
       "         [-0.4474],\n",
       "         [ 0.2729],\n",
       "         [-1.6906],\n",
       "         [ 1.3778],\n",
       "         [-1.0070],\n",
       "         [-0.8403],\n",
       "         [-0.7001],\n",
       "         [ 1.0406],\n",
       "         [-1.4446],\n",
       "         [-0.6466],\n",
       "         [-0.7166],\n",
       "         [-2.7278],\n",
       "         [ 0.7639],\n",
       "         [-1.0114],\n",
       "         [ 0.4035],\n",
       "         [-0.0967],\n",
       "         [ 0.1874],\n",
       "         [-0.0233],\n",
       "         [ 0.2500],\n",
       "         [-1.1852],\n",
       "         [ 0.3944],\n",
       "         [ 0.7337],\n",
       "         [-0.8858],\n",
       "         [ 0.2276],\n",
       "         [ 1.4584],\n",
       "         [-0.7382],\n",
       "         [-1.1533],\n",
       "         [-0.0996],\n",
       "         [-0.7100],\n",
       "         [-0.3638],\n",
       "         [-1.4409],\n",
       "         [-0.6670],\n",
       "         [-1.7975],\n",
       "         [-0.2160],\n",
       "         [-1.4487],\n",
       "         [-1.1392],\n",
       "         [-0.9036],\n",
       "         [-1.7611],\n",
       "         [ 0.3704],\n",
       "         [ 0.5397],\n",
       "         [-2.1478],\n",
       "         [-1.4618],\n",
       "         [-1.6414],\n",
       "         [-0.6269],\n",
       "         [-0.2829],\n",
       "         [ 0.0257],\n",
       "         [-0.1989],\n",
       "         [-2.0673],\n",
       "         [-0.6924],\n",
       "         [-0.3681],\n",
       "         [ 0.6297],\n",
       "         [ 0.3964],\n",
       "         [ 0.2439],\n",
       "         [-0.3956],\n",
       "         [-0.0949],\n",
       "         [ 0.0625],\n",
       "         [ 0.5635],\n",
       "         [ 0.8438],\n",
       "         [-0.2747],\n",
       "         [-1.1402],\n",
       "         [ 0.3218],\n",
       "         [-0.5094],\n",
       "         [-1.2667],\n",
       "         [-0.7848],\n",
       "         [-0.4328],\n",
       "         [-0.7558],\n",
       "         [-0.7198],\n",
       "         [-0.7660],\n",
       "         [ 0.7949],\n",
       "         [-0.5565],\n",
       "         [ 0.2850],\n",
       "         [-0.1917],\n",
       "         [ 0.2090],\n",
       "         [-0.7110],\n",
       "         [-0.4760],\n",
       "         [-2.2785],\n",
       "         [-2.7839],\n",
       "         [ 1.4080],\n",
       "         [-2.7538],\n",
       "         [ 0.8272],\n",
       "         [-1.2625],\n",
       "         [ 1.0311],\n",
       "         [-1.1250],\n",
       "         [-0.9827],\n",
       "         [-0.1242],\n",
       "         [ 1.4790],\n",
       "         [ 0.5899],\n",
       "         [-0.1593],\n",
       "         [-0.2475],\n",
       "         [ 0.7443],\n",
       "         [ 0.7036],\n",
       "         [ 0.1306],\n",
       "         [-0.8472],\n",
       "         [ 0.5299],\n",
       "         [-0.8480],\n",
       "         [-0.2363],\n",
       "         [ 0.0789],\n",
       "         [-0.7635],\n",
       "         [ 0.8476],\n",
       "         [-0.2516],\n",
       "         [-1.0596],\n",
       "         [ 0.5161],\n",
       "         [-0.4635],\n",
       "         [-1.2969],\n",
       "         [ 0.7626],\n",
       "         [ 0.3091],\n",
       "         [-0.2611],\n",
       "         [-0.5435],\n",
       "         [-0.7720],\n",
       "         [-1.7273],\n",
       "         [ 0.6817],\n",
       "         [-0.6991],\n",
       "         [ 0.6553],\n",
       "         [-0.8341],\n",
       "         [-1.7154],\n",
       "         [ 0.4859],\n",
       "         [-0.2933],\n",
       "         [-1.2694],\n",
       "         [-0.1433],\n",
       "         [ 0.9608],\n",
       "         [-0.1946],\n",
       "         [-1.2008],\n",
       "         [ 1.6482],\n",
       "         [ 0.3344],\n",
       "         [-0.9878],\n",
       "         [ 0.5632],\n",
       "         [ 0.7302],\n",
       "         [-0.6242],\n",
       "         [-0.5574],\n",
       "         [ 0.6164],\n",
       "         [ 0.5761],\n",
       "         [ 0.6418],\n",
       "         [-0.1734],\n",
       "         [ 0.5284],\n",
       "         [-1.4486],\n",
       "         [ 0.6077],\n",
       "         [ 0.5598],\n",
       "         [ 0.4934],\n",
       "         [-0.3721],\n",
       "         [-0.0991],\n",
       "         [-0.9200],\n",
       "         [ 0.5936],\n",
       "         [ 0.0696],\n",
       "         [-1.3983],\n",
       "         [ 1.1249],\n",
       "         [-0.3164],\n",
       "         [-0.6347],\n",
       "         [-0.7354],\n",
       "         [ 0.4323],\n",
       "         [-0.3369],\n",
       "         [ 0.8003],\n",
       "         [-2.5563],\n",
       "         [ 1.1910],\n",
       "         [-0.0647],\n",
       "         [ 0.5702],\n",
       "         [-0.0458],\n",
       "         [-0.2128],\n",
       "         [ 0.5138],\n",
       "         [ 1.1420],\n",
       "         [ 0.1118],\n",
       "         [ 0.2174],\n",
       "         [ 1.1929],\n",
       "         [-0.3623],\n",
       "         [ 1.3847],\n",
       "         [-0.0365],\n",
       "         [ 0.1581],\n",
       "         [ 0.5923],\n",
       "         [ 0.5838],\n",
       "         [-0.8108],\n",
       "         [ 1.3285],\n",
       "         [ 1.2175],\n",
       "         [-0.3995],\n",
       "         [-0.3820],\n",
       "         [-1.5060],\n",
       "         [ 0.8345],\n",
       "         [-0.9263],\n",
       "         [ 0.3198],\n",
       "         [-0.9400],\n",
       "         [ 0.5726]]),\n",
       " tensor([[-1.2222e-01, -9.9250e-01, -1.2855e+00],\n",
       "         [-2.2270e-01, -9.7489e-01, -2.0412e+00],\n",
       "         [-3.5320e-01, -9.3555e-01, -2.7281e+00],\n",
       "         [-5.1124e-01, -8.5944e-01, -3.5129e+00],\n",
       "         [-6.7750e-01, -7.3552e-01, -4.1546e+00],\n",
       "         [-8.2667e-01, -5.6269e-01, -4.5759e+00],\n",
       "         [-9.4257e-01, -3.3401e-01, -5.1417e+00],\n",
       "         [-9.9763e-01, -6.8863e-02, -5.4327e+00],\n",
       "         [-9.7932e-01,  2.0234e-01, -5.4533e+00],\n",
       "         [-8.9561e-01,  4.4483e-01, -5.1449e+00],\n",
       "         [-7.6215e-01,  6.4740e-01, -4.8634e+00],\n",
       "         [-5.9843e-01,  8.0117e-01, -4.5018e+00],\n",
       "         [-4.4529e-01,  8.9539e-01, -3.6010e+00],\n",
       "         [-3.1174e-01,  9.5017e-01, -2.8895e+00],\n",
       "         [-2.0532e-01,  9.7870e-01, -2.2047e+00],\n",
       "         [-1.2941e-01,  9.9159e-01, -1.5404e+00],\n",
       "         [-9.0604e-02,  9.9589e-01, -7.8085e-01],\n",
       "         [-9.8803e-02,  9.9511e-01,  1.6471e-01],\n",
       "         [-1.3320e-01,  9.9109e-01,  6.9272e-01],\n",
       "         [-2.0682e-01,  9.7838e-01,  1.4944e+00],\n",
       "         [-3.0596e-01,  9.5204e-01,  2.0526e+00],\n",
       "         [-4.3130e-01,  9.0221e-01,  2.6995e+00],\n",
       "         [-5.7842e-01,  8.1574e-01,  3.4171e+00],\n",
       "         [-7.2121e-01,  6.9271e-01,  3.7754e+00],\n",
       "         [-8.5762e-01,  5.1428e-01,  4.5016e+00],\n",
       "         [-9.5434e-01,  2.9872e-01,  4.7362e+00],\n",
       "         [-9.9810e-01,  6.1606e-02,  4.8342e+00],\n",
       "         [-9.8435e-01, -1.7620e-01,  4.7754e+00],\n",
       "         [-9.1427e-01, -4.0510e-01,  4.7993e+00],\n",
       "         [-8.0742e-01, -5.8998e-01,  4.2788e+00],\n",
       "         [-6.8368e-01, -7.2978e-01,  3.7394e+00],\n",
       "         [-5.6346e-01, -8.2614e-01,  3.0845e+00],\n",
       "         [-4.7091e-01, -8.8218e-01,  2.1649e+00],\n",
       "         [-3.9809e-01, -9.1735e-01,  1.6179e+00],\n",
       "         [-3.6210e-01, -9.3214e-01,  7.7814e-01],\n",
       "         [-3.5559e-01, -9.3464e-01,  1.3956e-01],\n",
       "         [-3.8235e-01, -9.2402e-01, -5.7592e-01],\n",
       "         [-4.3891e-01, -8.9853e-01, -1.2408e+00],\n",
       "         [-5.2294e-01, -8.5237e-01, -1.9182e+00],\n",
       "         [-6.2591e-01, -7.7990e-01, -2.5200e+00],\n",
       "         [-7.4493e-01, -6.6714e-01, -3.2827e+00],\n",
       "         [-8.5556e-01, -5.1771e-01, -3.7239e+00],\n",
       "         [-9.4139e-01, -3.3732e-01, -4.0021e+00],\n",
       "         [-9.9224e-01, -1.2434e-01, -4.3880e+00],\n",
       "         [-9.9523e-01,  9.7534e-02, -4.4471e+00],\n",
       "         [-9.5371e-01,  3.0072e-01, -4.1552e+00],\n",
       "         [-8.7398e-01,  4.8597e-01, -4.0404e+00],\n",
       "         [-7.6490e-01,  6.4415e-01, -3.8489e+00],\n",
       "         [-6.4563e-01,  7.6365e-01, -3.3807e+00],\n",
       "         [-5.2790e-01,  8.4931e-01, -2.9145e+00],\n",
       "         [-4.2550e-01,  9.0496e-01, -2.3321e+00],\n",
       "         [-3.3918e-01,  9.4072e-01, -1.8695e+00],\n",
       "         [-2.7908e-01,  9.6027e-01, -1.2640e+00],\n",
       "         [-2.3981e-01,  9.7082e-01, -8.1344e-01],\n",
       "         [-2.3409e-01,  9.7221e-01, -1.1773e-01],\n",
       "         [-2.5320e-01,  9.6741e-01,  3.9413e-01],\n",
       "         [-2.9879e-01,  9.5432e-01,  9.4881e-01],\n",
       "         [-3.7081e-01,  9.2871e-01,  1.5290e+00],\n",
       "         [-4.5996e-01,  8.8794e-01,  1.9614e+00],\n",
       "         [-5.7458e-01,  8.1845e-01,  2.6829e+00],\n",
       "         [-7.0397e-01,  7.1023e-01,  3.3777e+00],\n",
       "         [-8.2005e-01,  5.7230e-01,  3.6104e+00],\n",
       "         [-9.1379e-01,  4.0620e-01,  3.8203e+00],\n",
       "         [-9.7494e-01,  2.2247e-01,  3.8787e+00],\n",
       "         [-9.9964e-01,  2.6768e-02,  3.9516e+00],\n",
       "         [-9.8564e-01, -1.6888e-01,  3.9292e+00],\n",
       "         [-9.3589e-01, -3.5228e-01,  3.8064e+00],\n",
       "         [-8.5995e-01, -5.1038e-01,  3.5124e+00],\n",
       "         [-7.7939e-01, -6.2654e-01,  2.8296e+00],\n",
       "         [-7.0392e-01, -7.1028e-01,  2.2558e+00],\n",
       "         [-6.4231e-01, -7.6645e-01,  1.6679e+00],\n",
       "         [-5.9569e-01, -8.0321e-01,  1.1875e+00],\n",
       "         [-5.6950e-01, -8.2199e-01,  6.4456e-01],\n",
       "         [-5.6684e-01, -8.2383e-01,  6.4653e-02],\n",
       "         [-5.9181e-01, -8.0608e-01, -6.1256e-01],\n",
       "         [-6.4028e-01, -7.6814e-01, -1.2314e+00],\n",
       "         [-7.0666e-01, -7.0755e-01, -1.7981e+00],\n",
       "         [-7.8145e-01, -6.2397e-01, -2.2442e+00],\n",
       "         [-8.5537e-01, -5.1802e-01, -2.5856e+00],\n",
       "         [-9.2347e-01, -3.8367e-01, -3.0154e+00],\n",
       "         [-9.7588e-01, -2.1829e-01, -3.4741e+00],\n",
       "         [-9.9918e-01, -4.0569e-02, -3.5896e+00],\n",
       "         [-9.8962e-01,  1.4374e-01, -3.6964e+00],\n",
       "         [-9.4501e-01,  3.2704e-01, -3.7786e+00],\n",
       "         [-8.6994e-01,  4.9316e-01, -3.6510e+00],\n",
       "         [-7.7567e-01,  6.3114e-01, -3.3461e+00],\n",
       "         [-6.7315e-01,  7.3950e-01, -2.9861e+00],\n",
       "         [-5.7409e-01,  8.1879e-01, -2.5394e+00],\n",
       "         [-4.8772e-01,  8.7300e-01, -2.0403e+00],\n",
       "         [-4.3151e-01,  9.0211e-01, -1.2663e+00],\n",
       "         [-4.0091e-01,  9.1612e-01, -6.7316e-01],\n",
       "         [-4.0350e-01,  9.1498e-01,  5.6671e-02],\n",
       "         [-4.3591e-01,  8.9999e-01,  7.1414e-01],\n",
       "         [-4.9868e-01,  8.6679e-01,  1.4205e+00],\n",
       "         [-5.8125e-01,  8.1372e-01,  1.9639e+00],\n",
       "         [-6.7827e-01,  7.3481e-01,  2.5028e+00],\n",
       "         [-7.7271e-01,  6.3475e-01,  2.7539e+00],\n",
       "         [-8.5710e-01,  5.1516e-01,  2.9300e+00],\n",
       "         [-9.3419e-01,  3.5677e-01,  3.5276e+00],\n",
       "         [-9.8199e-01,  1.8891e-01,  3.4951e+00],\n",
       "         [-1.0000e+00,  2.0086e-03,  3.7609e+00],\n",
       "         [-9.8444e-01, -1.7573e-01,  3.5731e+00],\n",
       "         [-9.3714e-01, -3.4894e-01,  3.5959e+00],\n",
       "         [-8.7043e-01, -4.9229e-01,  3.1655e+00],\n",
       "         [-7.9780e-01, -6.0292e-01,  2.6489e+00],\n",
       "         [-7.2755e-01, -6.8606e-01,  2.1780e+00],\n",
       "         [-6.5974e-01, -7.5150e-01,  1.8853e+00],\n",
       "         [-6.0516e-01, -7.9611e-01,  1.4102e+00],\n",
       "         [-5.7328e-01, -8.1936e-01,  7.8923e-01],\n",
       "         [-5.6763e-01, -8.2329e-01,  1.3758e-01],\n",
       "         [-5.8269e-01, -8.1270e-01, -3.6824e-01],\n",
       "         [-6.1757e-01, -7.8652e-01, -8.7222e-01],\n",
       "         [-6.7264e-01, -7.3997e-01, -1.4425e+00],\n",
       "         [-7.4731e-01, -6.6448e-01, -2.1246e+00],\n",
       "         [-8.2555e-01, -5.6433e-01, -2.5434e+00],\n",
       "         [-9.0264e-01, -4.3040e-01, -3.0939e+00],\n",
       "         [-9.6315e-01, -2.6897e-01, -3.4521e+00],\n",
       "         [-9.9593e-01, -9.0100e-02, -3.6420e+00],\n",
       "         [-9.9491e-01,  1.0081e-01, -3.8241e+00],\n",
       "         [-9.6049e-01,  2.7833e-01, -3.6214e+00],\n",
       "         [-8.9845e-01,  4.3908e-01, -3.4504e+00],\n",
       "         [-8.1471e-01,  5.7987e-01, -3.2800e+00],\n",
       "         [-7.2693e-01,  6.8671e-01, -2.7677e+00],\n",
       "         [-6.4248e-01,  7.6630e-01, -2.3222e+00],\n",
       "         [-5.6516e-01,  8.2498e-01, -1.9420e+00],\n",
       "         [-5.1429e-01,  8.5761e-01, -1.2089e+00],\n",
       "         [-4.9186e-01,  8.7068e-01, -5.1930e-01],\n",
       "         [-4.9597e-01,  8.6834e-01,  9.4548e-02],\n",
       "         [-5.2453e-01,  8.5139e-01,  6.6428e-01],\n",
       "         [-5.7411e-01,  8.1878e-01,  1.1870e+00],\n",
       "         [-6.3547e-01,  7.7213e-01,  1.5420e+00],\n",
       "         [-7.1720e-01,  6.9686e-01,  2.2234e+00],\n",
       "         [-8.0272e-01,  5.9636e-01,  2.6411e+00],\n",
       "         [-8.8717e-01,  4.6144e-01,  3.1867e+00],\n",
       "         [-9.5257e-01,  3.0433e-01,  3.4077e+00],\n",
       "         [-9.9017e-01,  1.3985e-01,  3.3786e+00],\n",
       "         [-9.9930e-01, -3.7501e-02,  3.5564e+00],\n",
       "         [-9.7767e-01, -2.1014e-01,  3.4842e+00],\n",
       "         [-9.3286e-01, -3.6025e-01,  3.1362e+00],\n",
       "         [-8.7237e-01, -4.8884e-01,  2.8445e+00],\n",
       "         [-8.0098e-01, -5.9869e-01,  2.6220e+00],\n",
       "         [-7.3233e-01, -6.8095e-01,  2.1438e+00],\n",
       "         [-6.8097e-01, -7.3231e-01,  1.4530e+00],\n",
       "         [-6.3773e-01, -7.7026e-01,  1.1510e+00],\n",
       "         [-6.1341e-01, -7.8977e-01,  6.2345e-01],\n",
       "         [-6.1802e-01, -7.8616e-01, -1.1705e-01],\n",
       "         [-6.4217e-01, -7.6656e-01, -6.2219e-01],\n",
       "         [-6.8289e-01, -7.3052e-01, -1.0876e+00],\n",
       "         [-7.4342e-01, -6.6883e-01, -1.7291e+00],\n",
       "         [-8.1567e-01, -5.7852e-01, -2.3143e+00],\n",
       "         [-8.8508e-01, -4.6543e-01, -2.6558e+00],\n",
       "         [-9.4335e-01, -3.3179e-01, -2.9184e+00],\n",
       "         [-9.8300e-01, -1.8360e-01, -3.0710e+00],\n",
       "         [-9.9974e-01, -2.2914e-02, -3.2347e+00],\n",
       "         [-9.9080e-01,  1.3530e-01, -3.1726e+00],\n",
       "         [-9.5530e-01,  2.9565e-01, -3.2885e+00],\n",
       "         [-9.0092e-01,  4.3399e-01, -2.9756e+00],\n",
       "         [-8.3798e-01,  5.4570e-01, -2.5661e+00],\n",
       "         [-7.7672e-01,  6.2985e-01, -2.0828e+00],\n",
       "         [-7.2161e-01,  6.9230e-01, -1.6662e+00],\n",
       "         [-6.8019e-01,  7.3303e-01, -1.1619e+00],\n",
       "         [-6.5223e-01,  7.5802e-01, -7.5011e-01],\n",
       "         [-6.4871e-01,  7.6103e-01, -9.2548e-02],\n",
       "         [-6.6711e-01,  7.4496e-01,  4.8867e-01],\n",
       "         [-6.9772e-01,  7.1637e-01,  8.3765e-01],\n",
       "         [-7.5088e-01,  6.6044e-01,  1.5437e+00],\n",
       "         [-8.1282e-01,  5.8252e-01,  1.9915e+00],\n",
       "         [-8.7509e-01,  4.8395e-01,  2.3332e+00],\n",
       "         [-9.3019e-01,  3.6708e-01,  2.5859e+00],\n",
       "         [-9.7376e-01,  2.2756e-01,  2.9260e+00],\n",
       "         [-9.9702e-01,  7.7184e-02,  3.0462e+00],\n",
       "         [-9.9648e-01, -8.3846e-02,  3.2241e+00],\n",
       "         [-9.7434e-01, -2.2506e-01,  2.8612e+00],\n",
       "         [-9.3212e-01, -3.6214e-01,  2.8711e+00],\n",
       "         [-8.7756e-01, -4.7947e-01,  2.5898e+00],\n",
       "         [-8.1629e-01, -5.7764e-01,  2.3157e+00],\n",
       "         [-7.5861e-01, -6.5154e-01,  1.8756e+00],\n",
       "         [-7.1276e-01, -7.0140e-01,  1.3550e+00],\n",
       "         [-6.8027e-01, -7.3296e-01,  9.0604e-01],\n",
       "         [-6.6070e-01, -7.5065e-01,  5.2761e-01],\n",
       "         [-6.6140e-01, -7.5004e-01, -1.8606e-02],\n",
       "         [-6.8172e-01, -7.3162e-01, -5.4852e-01],\n",
       "         [-7.1458e-01, -6.9956e-01, -9.1830e-01],\n",
       "         [-7.6490e-01, -6.4415e-01, -1.4973e+00],\n",
       "         [-8.1892e-01, -5.7391e-01, -1.7727e+00],\n",
       "         [-8.7718e-01, -4.8016e-01, -2.2086e+00],\n",
       "         [-9.3102e-01, -3.6496e-01, -2.5450e+00],\n",
       "         [-9.7202e-01, -2.3488e-01, -2.7299e+00],\n",
       "         [-9.9538e-01, -9.6020e-02, -2.8185e+00],\n",
       "         [-9.9852e-01,  5.4412e-02, -3.0121e+00],\n",
       "         [-9.8143e-01,  1.9185e-01, -2.7721e+00],\n",
       "         [-9.5070e-01,  3.1012e-01, -2.4456e+00],\n",
       "         [-9.0940e-01,  4.1593e-01, -2.2729e+00],\n",
       "         [-8.6287e-01,  5.0543e-01, -2.0182e+00],\n",
       "         [-8.1206e-01,  5.8358e-01, -1.8651e+00],\n",
       "         [-7.7236e-01,  6.3518e-01, -1.3022e+00],\n",
       "         [-7.4084e-01,  6.7168e-01, -9.6477e-01],\n",
       "         [-7.2681e-01,  6.8684e-01, -4.1304e-01],\n",
       "         [-7.2547e-01,  6.8825e-01, -3.8912e-02],\n",
       "         [-7.4456e-01,  6.6755e-01,  5.6316e-01]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1]]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T08:53:39.148982Z",
     "start_time": "2025-09-03T08:53:39.133756Z"
    }
   },
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #记录反馈值的和,这个值越大越好\n",
    "    reward_sum = 0\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        action = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        state, reward, over, _ = env.step([action])\n",
    "        reward_sum += reward\n",
    "\n",
    "        #打印动画\n",
    "        if play and random.random() < 0.2:  #跳帧\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1342.1135451139558"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T08:53:39.158653Z",
     "start_time": "2025-09-03T08:53:39.156358Z"
    }
   },
   "source": [
    "#优势函数\n",
    "def get_advantages(deltas):\n",
    "    advantages = []\n",
    "\n",
    "    #反向遍历deltas\n",
    "    s = 0.0\n",
    "    for delta in deltas[::-1]:\n",
    "        s = 0.9 * 0.9 * s + delta\n",
    "        advantages.append(s)\n",
    "\n",
    "    #逆序\n",
    "    advantages.reverse()\n",
    "    return advantages\n",
    "\n",
    "\n",
    "get_advantages(range(5))"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[5.43839184, 6.7140640000000005, 7.0544, 6.24, 4.0]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "metadata": {
    "executionInfo": {
     "elapsed": 8251,
     "status": "ok",
     "timestamp": 1650011468229,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "BQXVYW2T_DcQ",
    "scrolled": false,
    "ExecuteTime": {
     "end_time": "2025-09-03T08:54:28.732425Z",
     "start_time": "2025-09-03T08:53:39.166290Z"
    }
   },
   "source": [
    "def train():\n",
    "    # Train the policy network (model) with a PPO-style clipped surrogate\n",
    "    # objective and the value network (model_td) with an MSE loss against\n",
    "    # one-step TD targets. Every 200 epochs the average return of 10 test\n",
    "    # episodes is printed. Takes no arguments and returns nothing.\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "    optimizer_td = torch.optim.Adam(model_td.parameters(), lr=5e-3)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    # Play N episodes; the data of each episode is trained on M times.\n",
    "    for epoch in range(3000):\n",
    "        # Play one episode and collect its data.\n",
    "        # states -> [b, 3]\n",
    "        # rewards -> [b, 1]\n",
    "        # actions -> [b, 1]\n",
    "        # next_states -> [b, 3]\n",
    "        # overs -> [b, 1]\n",
    "        states, rewards, actions, next_states, overs = get_data()\n",
    "\n",
    "        # Shift and rescale the rewards to make training easier.\n",
    "        rewards = (rewards + 8) / 8\n",
    "\n",
    "        # Compute values and targets.\n",
    "        # [b, 3] -> [b, 1]\n",
    "        values = model_td(states)\n",
    "\n",
    "        # TD target: reward + 0.98 * V(next_state); the bootstrap term is\n",
    "        # zeroed wherever overs is 1. Detached so it acts as a constant.\n",
    "        # [b, 3] -> [b, 1]\n",
    "        targets = model_td(next_states).detach()\n",
    "        targets = targets * 0.98\n",
    "        targets *= (1 - overs)\n",
    "        targets += rewards\n",
    "\n",
    "        # Advantages play the role that the reward sum plays in plain policy\n",
    "        # gradient, but are built from target - value instead of raw rewards.\n",
    "        # [b, 1]\n",
    "        deltas = (targets - values).squeeze(dim=1).tolist()\n",
    "        advantages = get_advantages(deltas)\n",
    "        advantages = torch.FloatTensor(advantages).reshape(-1, 1)\n",
    "\n",
    "        # Log-probability of each taken action under the pre-update policy.\n",
    "        # No gradient is needed here: this is the fixed reference point.\n",
    "        # [b, 3] -> [b, 1],[b, 1]\n",
    "        with torch.no_grad():\n",
    "            mu, std = model(states)\n",
    "            # [b, 1]\n",
    "            old_log_probs = torch.distributions.Normal(mu, std).log_prob(actions)\n",
    "\n",
    "        # Reuse every batch for 10 optimisation steps.\n",
    "        for _ in range(10):\n",
    "            # Re-evaluate the log-probabilities with the current policy.\n",
    "            # [b, 3] -> [b, 1],[b, 1]\n",
    "            mu, std = model(states)\n",
    "            # [b, 1]\n",
    "            new_log_probs = torch.distributions.Normal(mu, std).log_prob(actions)\n",
    "\n",
    "            # Probability ratio new / old, evaluated in log space so that a\n",
    "            # density underflowing to 0 cannot produce 0 / 0 = NaN.\n",
    "            # [b, 1] - [b, 1] -> [b, 1]\n",
    "            ratios = (new_log_probs - old_log_probs).exp()\n",
    "\n",
    "            # Clipped and unclipped surrogate terms; keep the smaller one.\n",
    "            # [b, 1] * [b, 1] -> [b, 1]\n",
    "            surr1 = ratios * advantages\n",
    "            # [b, 1] * [b, 1] -> [b, 1]\n",
    "            surr2 = torch.clamp(ratios, 0.8, 1.2) * advantages\n",
    "\n",
    "            loss = -torch.min(surr1, surr2)\n",
    "            loss = loss.mean()\n",
    "\n",
    "            # Recompute the values and take the temporal-difference loss.\n",
    "            values = model_td(states)\n",
    "            loss_td = loss_fn(values, targets)\n",
    "\n",
    "            # Update the policy parameters, then the value parameters.\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            optimizer_td.zero_grad()\n",
    "            loss_td.backward()\n",
    "            optimizer_td.step()\n",
    "\n",
    "        # Periodic evaluation: mean return over 10 episodes without rendering.\n",
    "        if epoch % 200 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(10)]) / 10\n",
    "            print(epoch, test_result)\n",
    "\n",
    "\n",
    "train()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -1192.465293061399\n",
      "200 -1035.7373974780714\n",
      "400 -874.0707008872608\n",
      "600 -852.1135518376617\n",
      "800 -810.581807767871\n",
      "1000 -635.3559709568706\n",
      "1200 -635.8518121821742\n",
      "1400 -709.5810282565767\n",
      "1600 -426.91806455912376\n",
      "1800 -448.4908949721399\n",
      "2000 -496.43754633029346\n",
      "2200 -479.32443300366623\n",
      "2400 -296.55205779840895\n",
      "2600 -647.1533822599787\n",
      "2800 -487.29966665385456\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "test(play=True)"
   ],
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-125.8941135769733"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
